{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n",
      "/data/users/fyx/.local/python3/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88\n",
      "  return f(*args, **kwds)\n",
      "/data/users/fyx/.local/python3/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88\n",
      "  return f(*args, **kwds)\n"
     ]
    }
   ],
   "source": [
    "import keras\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matchzoo as mz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_pack = mz.datasets.wiki_qa.load_data('train', task='ranking')\n",
    "valid_pack = mz.datasets.wiki_qa.load_data('dev', task='ranking', filter=True)\n",
    "predict_pack = mz.datasets.wiki_qa.load_data('test', task='ranking', filter=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing text_left with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit: 100%|██████████| 2118/2118 [00:00<00:00, 8442.14it/s]\n",
      "Processing text_right with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit: 100%|██████████| 18841/18841 [00:04<00:00, 4635.69it/s]\n",
      "Processing text_right with append: 100%|██████████| 18841/18841 [00:00<00:00, 734036.32it/s]\n",
      "Building FrequencyFilterUnit from a datapack.: 100%|██████████| 18841/18841 [00:00<00:00, 126464.70it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 77951.57it/s] \n",
      "Processing text_left with extend: 100%|██████████| 2118/2118 [00:00<00:00, 572171.57it/s]\n",
      "Processing text_right with extend: 100%|██████████| 18841/18841 [00:00<00:00, 653303.37it/s]\n",
      "Building VocabularyUnit from a datapack.: 100%|██████████| 404415/404415 [00:00<00:00, 2571197.34it/s]\n",
      "Processing text_left with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit: 100%|██████████| 2118/2118 [00:00<00:00, 8680.29it/s]\n",
      "Processing text_right with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit: 100%|██████████| 18841/18841 [00:04<00:00, 4626.89it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 111864.34it/s]\n",
      "Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 154171.84it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 115907.95it/s]\n",
      "Processing length_left with len: 100%|██████████| 2118/2118 [00:00<00:00, 452595.06it/s]\n",
      "Processing length_right with len: 100%|██████████| 18841/18841 [00:00<00:00, 673917.23it/s]\n",
      "Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 114512.50it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 77189.17it/s]\n"
     ]
    }
   ],
   "source": [
    "preprocessor = mz.preprocessors.BasicPreprocessor(fixed_length_left=10, fixed_length_right=100, remove_stop_words=False)\n",
    "train_pack_processed = preprocessor.fit_transform(train_pack)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing text_left with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit: 100%|██████████| 122/122 [00:00<00:00, 8015.93it/s]\n",
      "Processing text_right with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit: 100%|██████████| 1115/1115 [00:00<00:00, 4625.89it/s]\n",
      "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 106743.56it/s]\n",
      "Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 124170.13it/s]\n",
      "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 121562.97it/s]\n",
      "Processing length_left with len: 100%|██████████| 122/122 [00:00<00:00, 139125.91it/s]\n",
      "Processing length_right with len: 100%|██████████| 1115/1115 [00:00<00:00, 475801.09it/s]\n",
      "Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 68888.68it/s]\n",
      "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 52856.63it/s]\n",
      "Processing text_left with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit: 100%|██████████| 237/237 [00:00<00:00, 9184.01it/s]\n",
      "Processing text_right with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit: 100%|██████████| 2300/2300 [00:00<00:00, 4696.32it/s]\n",
      "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 128198.00it/s]\n",
      "Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 137584.78it/s]\n",
      "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 129007.18it/s]\n",
      "Processing length_left with len: 100%|██████████| 237/237 [00:00<00:00, 222731.36it/s]\n",
      "Processing length_right with len: 100%|██████████| 2300/2300 [00:00<00:00, 565203.84it/s]\n",
      "Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 99784.18it/s]\n",
      "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 72382.33it/s]\n"
     ]
    }
   ],
   "source": [
    "valid_pack_processed = preprocessor.transform(valid_pack)\n",
    "predict_pack_processed = preprocessor.transform(predict_pack)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "ranking_task = mz.tasks.Ranking(loss=mz.losses.RankHingeLoss())\n",
    "ranking_task.metrics = [\n",
    "    mz.metrics.NormalizedDiscountedCumulativeGain(k=3),\n",
    "    mz.metrics.NormalizedDiscountedCumulativeGain(k=5),\n",
    "    mz.metrics.MeanAveragePrecision()\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Parameter \"name\" set to DRMMTKS.\n",
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "text_left (InputLayer)          (None, 10)           0                                            \n",
      "__________________________________________________________________________________________________\n",
      "text_right (InputLayer)         (None, 100)          0                                            \n",
      "__________________________________________________________________________________________________\n",
      "embedding (Embedding)           multiple             1667400     text_left[0][0]                  \n",
      "                                                                 text_right[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "dot_1 (Dot)                     (None, 10, 100)      0           embedding[0][0]                  \n",
      "                                                                 embedding[1][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "dense_1 (Dense)                 (None, 10, 1)        100         embedding[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "lambda_1 (Lambda)               (None, 10, 20)       0           dot_1[0][0]                      \n",
      "__________________________________________________________________________________________________\n",
      "attention_mask (Lambda)         (None, 10, 1)        0           dense_1[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "dense_2 (Dense)                 (None, 10, 5)        105         lambda_1[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "attention_probs (Lambda)        (None, 10, 1)        0           attention_mask[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "dense_3 (Dense)                 (None, 10, 1)        6           dense_2[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "dot_2 (Dot)                     (None, 1, 1)         0           attention_probs[0][0]            \n",
      "                                                                 dense_3[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "flatten_1 (Flatten)             (None, 1)            0           dot_2[0][0]                      \n",
      "__________________________________________________________________________________________________\n",
      "dense_4 (Dense)                 (None, 1)            2           flatten_1[0][0]                  \n",
      "==================================================================================================\n",
      "Total params: 1,667,613\n",
      "Trainable params: 1,667,613\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model = mz.models.DRMMTKS()\n",
    "model.params['input_shapes'] = preprocessor.context['input_shapes']\n",
    "model.params['task'] = ranking_task\n",
    "model.params['mask_value'] = -1\n",
    "model.params['embedding_input_dim'] = preprocessor.context['vocab_size']\n",
    "model.params['embedding_output_dim'] = 100\n",
    "model.params['embedding_trainable'] = True\n",
    "model.params['top_k'] = 20\n",
    "model.params['mlp_num_layers'] = 1\n",
    "model.params['mlp_num_units'] = 5\n",
    "model.params['mlp_num_fan_out'] = 1\n",
    "model.params['mlp_activation_func'] = 'relu'\n",
    "model.params['optimizer'] = 'adadelta'\n",
    "model.guess_and_fill_missing_params()\n",
    "model.build()\n",
    "model.compile()\n",
    "model.backend.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "glove_embedding = mz.datasets.embeddings.load_glove_embedding(dimension=100)\n",
    "embedding_matrix = glove_embedding.build_matrix(preprocessor.context['vocab_unit'].state['term_index'])\n",
    "l2_norm = np.sqrt((embedding_matrix*embedding_matrix).sum(axis=1))\n",
    "embedding_matrix = embedding_matrix / l2_norm[:, np.newaxis]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.load_embedding_matrix(embedding_matrix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "pred_x, pred_y = predict_pack_processed[:].unpack()\n",
    "evaluate = mz.callbacks.EvaluateAllMetrics(model, x=pred_x, y=pred_y, batch_size=len(pred_x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "102"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_generator = mz.PairDataGenerator(train_pack_processed, num_dup=2, num_neg=1, batch_size=20)\n",
    "len(train_generator)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/30\n",
      "102/102 [==============================] - 4s 42ms/step - loss: 0.9129\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.44913793360110954 - normalized_discounted_cumulative_gain@5(0): 0.5032644544569836 - mean_average_precision(0): 0.47069452226494585\n",
      "Epoch 2/30\n",
      "102/102 [==============================] - 3s 33ms/step - loss: 0.7650\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5133645680060057 - normalized_discounted_cumulative_gain@5(0): 0.5712097375366855 - mean_average_precision(0): 0.534641819193907\n",
      "Epoch 3/30\n",
      "102/102 [==============================] - 3s 33ms/step - loss: 0.6033\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.599709051518077 - normalized_discounted_cumulative_gain@5(0): 0.6483239888002686 - mean_average_precision(0): 0.6077730086244657\n",
      "Epoch 4/30\n",
      "102/102 [==============================] - 3s 33ms/step - loss: 0.4298\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6266908483818189 - normalized_discounted_cumulative_gain@5(0): 0.6788882419934023 - mean_average_precision(0): 0.6380641803110157\n",
      "Epoch 5/30\n",
      "102/102 [==============================] - 3s 33ms/step - loss: 0.3236\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6158753963293319 - normalized_discounted_cumulative_gain@5(0): 0.6796607353618087 - mean_average_precision(0): 0.6324264604979423\n",
      "Epoch 6/30\n",
      "102/102 [==============================] - 3s 33ms/step - loss: 0.2393\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6269808134418329 - normalized_discounted_cumulative_gain@5(0): 0.6827464590772078 - mean_average_precision(0): 0.6402156298377446\n",
      "Epoch 7/30\n",
      "102/102 [==============================] - 3s 33ms/step - loss: 0.1726\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6300851677408054 - normalized_discounted_cumulative_gain@5(0): 0.6903548664702603 - mean_average_precision(0): 0.6442976999008272\n",
      "Epoch 8/30\n",
      "102/102 [==============================] - 3s 33ms/step - loss: 0.1257\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6320163678874755 - normalized_discounted_cumulative_gain@5(0): 0.6953126612811423 - mean_average_precision(0): 0.6512669261738508\n",
      "Epoch 9/30\n",
      "102/102 [==============================] - 3s 33ms/step - loss: 0.0894\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6329108879315698 - normalized_discounted_cumulative_gain@5(0): 0.6905197789857627 - mean_average_precision(0): 0.6425124476805709\n",
      "Epoch 10/30\n",
      "102/102 [==============================] - 3s 32ms/step - loss: 0.0645\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6401061305937785 - normalized_discounted_cumulative_gain@5(0): 0.6922540467452605 - mean_average_precision(0): 0.6471011461914585\n",
      "Epoch 11/30\n",
      "102/102 [==============================] - 3s 32ms/step - loss: 0.0468\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6360357297906899 - normalized_discounted_cumulative_gain@5(0): 0.6862065097171752 - mean_average_precision(0): 0.6410820656470614\n",
      "Epoch 12/30\n",
      "102/102 [==============================] - 3s 32ms/step - loss: 0.0342\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6316776370399996 - normalized_discounted_cumulative_gain@5(0): 0.687222285852765 - mean_average_precision(0): 0.637158245127932\n",
      "Epoch 13/30\n",
      "102/102 [==============================] - 3s 32ms/step - loss: 0.0247\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6322300832576006 - normalized_discounted_cumulative_gain@5(0): 0.6868518406317019 - mean_average_precision(0): 0.6378616322237242\n",
      "Epoch 14/30\n",
      "102/102 [==============================] - 3s 32ms/step - loss: 0.0177\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6302704424205204 - normalized_discounted_cumulative_gain@5(0): 0.6868938682568098 - mean_average_precision(0): 0.633694128372676\n",
      "Epoch 15/30\n",
      "102/102 [==============================] - 3s 33ms/step - loss: 0.0133\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6305090930389422 - normalized_discounted_cumulative_gain@5(0): 0.6886525267141036 - mean_average_precision(0): 0.6388478354251171\n",
      "Epoch 16/30\n",
      "102/102 [==============================] - 3s 32ms/step - loss: 0.0104\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6341157475172039 - normalized_discounted_cumulative_gain@5(0): 0.6904313973845297 - mean_average_precision(0): 0.6404569038063377\n",
      "Epoch 17/30\n",
      "102/102 [==============================] - 3s 33ms/step - loss: 0.0081\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6341157475172039 - normalized_discounted_cumulative_gain@5(0): 0.6882659843545355 - mean_average_precision(0): 0.6381009147414877\n",
      "Epoch 18/30\n",
      "102/102 [==============================] - 3s 33ms/step - loss: 0.0064\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6399776096743143 - normalized_discounted_cumulative_gain@5(0): 0.6916197043955237 - mean_average_precision(0): 0.6449520724420791\n",
      "Epoch 19/30\n",
      "102/102 [==============================] - 3s 32ms/step - loss: 0.0051\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6350569081574968 - normalized_discounted_cumulative_gain@5(0): 0.6864606431812963 - mean_average_precision(0): 0.6380027980877414\n",
      "Epoch 20/30\n",
      "102/102 [==============================] - 3s 33ms/step - loss: 0.0042\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6349420552286718 - normalized_discounted_cumulative_gain@5(0): 0.6857270341115249 - mean_average_precision(0): 0.6390406022901026\n",
      "Epoch 21/30\n",
      "102/102 [==============================] - 3s 33ms/step - loss: 0.0037\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6391614645113721 - normalized_discounted_cumulative_gain@5(0): 0.6915072006355713 - mean_average_precision(0): 0.6448192081686419\n",
      "Epoch 22/30\n",
      "102/102 [==============================] - 3s 32ms/step - loss: 0.0031\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6360617203636448 - normalized_discounted_cumulative_gain@5(0): 0.6917377225240787 - mean_average_precision(0): 0.6447306869358834\n",
      "Epoch 23/30\n",
      "102/102 [==============================] - 3s 32ms/step - loss: 0.0029\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6316536440377385 - normalized_discounted_cumulative_gain@5(0): 0.6891245863223207 - mean_average_precision(0): 0.6400098042086716\n",
      "Epoch 24/30\n",
      "102/102 [==============================] - 3s 32ms/step - loss: 0.0025\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6308374988747965 - normalized_discounted_cumulative_gain@5(0): 0.6889398972541687 - mean_average_precision(0): 0.6396379011468951\n",
      "Epoch 25/30\n",
      "102/102 [==============================] - 3s 32ms/step - loss: 0.0025\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6308374988747965 - normalized_discounted_cumulative_gain@5(0): 0.6888265199976148 - mean_average_precision(0): 0.63942693068276\n",
      "Epoch 26/30\n",
      "102/102 [==============================] - 3s 33ms/step - loss: 0.0023\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6323947572985456 - normalized_discounted_cumulative_gain@5(0): 0.6870476646850574 - mean_average_precision(0): 0.6412442376632917\n",
      "Epoch 27/30\n",
      "102/102 [==============================] - 3s 32ms/step - loss: 0.0022\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6297326064395942 - normalized_discounted_cumulative_gain@5(0): 0.6843855138261061 - mean_average_precision(0): 0.637700629372848\n",
      "Epoch 28/30\n",
      "102/102 [==============================] - 3s 32ms/step - loss: 0.0022\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6302850526571954 - normalized_discounted_cumulative_gain@5(0): 0.684937960043707 - mean_average_precision(0): 0.6384624671600022\n",
      "Epoch 29/30\n",
      "102/102 [==============================] - 3s 33ms/step - loss: 0.0021\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6310261659180024 - normalized_discounted_cumulative_gain@5(0): 0.6846782390734201 - mean_average_precision(0): 0.6382485204853626\n",
      "Epoch 30/30\n",
      "102/102 [==============================] - 3s 32ms/step - loss: 0.0022\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6327971397118766 - normalized_discounted_cumulative_gain@5(0): 0.683184632215526 - mean_average_precision(0): 0.6379469818926847\n"
     ]
    }
   ],
   "source": [
    "history = model.fit_generator(train_generator, epochs=30, callbacks=[evaluate], workers=30, use_multiprocessing=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
